import warnings
warnings.filterwarnings("ignore")

import torch as T
import torch.nn.functional as F
import numpy as np
import tqdm
import pickle
import gzip
import argparse
import random

parser = argparse.ArgumentParser()

parser.add_argument("--save_path", default='results/tb_r0.pkl.gz', type=str)
parser.add_argument("--device", default='cuda', type=str)

# GFlowNet settings
# Defaults are given with their real types (argparse would otherwise have to
# parse string defaults through `type`).
parser.add_argument("--seed", default=0, type=int)
parser.add_argument("--method", default='flownet', type=str)
parser.add_argument("--loss", default='trajectory_balance', type=str, help="Detailed balance loss or Trajectory balance loss")
parser.add_argument("--pb_type", default='flexible', type=str, help="flexible or uniform")
parser.add_argument("--learning_rate_model", default=1e-3, help="Learning rate for model parameters", type=float)
parser.add_argument("--learning_rate_z", default=1e-1, help="Learning rate for Z", type=float)
parser.add_argument("--mbsize", default=16, help="Minibatch size", type=int)
parser.add_argument("--n_hid", default=256, type=int)
parser.add_argument("--n_layers", default=2, type=int)
parser.add_argument("--n_train_steps", default=1, type=int)
parser.add_argument("--num_empirical_loss", default=200000, type=int,
                    help="Number of samples used to compute the empirical distribution loss")
# Mixing weight of uniform exploration in the forward sampling policy.
parser.add_argument('--exp_weight', default=0.0, type=float)

# OT regularization
parser.add_argument("--reg_coef", default=1e-1, type=float, help="Coefficient for regularisation term for main objective loss")

# Env
parser.add_argument("--horizon", default=8, type=int)
# Constant added to every state's reward in log_reward. R0 = 1e-1, R1 = 1e-2, R2 = 1e-3
parser.add_argument('--r', default=0.1, type=float)
parser.add_argument("--ndim", default=4, type=int)

# Make MLP model
def make_mlp(l, act=T.nn.LeakyReLU(), tail=None):
    """Build a fully-connected network from a list of layer widths.

    Args:
        l: Layer widths ``[in, hidden..., out]``; one ``Linear`` is created
            for every consecutive pair of widths.
        act: Activation module inserted after every ``Linear`` except the
            last one. The same module instance is reused at each position.
        tail: Optional iterable of extra modules appended after the final
            ``Linear``. Defaults to none (``None`` avoids a shared mutable
            default list).

    Returns:
        A ``torch.nn.Sequential`` holding the layers in order.
    """
    tail = [] if tail is None else list(tail)
    layers = []
    for n, (i, o) in enumerate(zip(l, l[1:])):
        layers.append(T.nn.Linear(i, o))
        # No activation after the output layer.
        if n < len(l) - 2:
            layers.append(act)
    return T.nn.Sequential(*layers, *tail)

# Main
def main(args):
    """Train a GFlowNet on the hypergrid with an OT regularisation term.

    Each training step samples ``args.mbsize`` trajectories from the learned
    forward policy. It accumulates the matching residuals (trajectory balance
    by default) plus an OT regularisation term, and takes one Adam step on the
    model parameters and on the learnable ``Z``.

    Every 100 steps the empirical distribution of visited terminal states is
    compared with the true normalised reward (L1 and KL). At the end, all logs,
    the model parameters/state_dict and ``args`` are pickled (gzip-compressed)
    to ``args.save_path``.

    Args:
        args: Namespace produced by the module-level ``parser``.
    """
    import os

    # Parse argument
    device = args.device
    horizon = args.horizon
    ndim = args.ndim
    n_hid = args.n_hid
    n_layers = args.n_layers
    bs = args.mbsize
    # Accept both "detailed balance" and "detailed_balance"; the default
    # ("trajectory_balance") uses the underscore spelling.
    detailed_balance = (args.loss.replace("_", " ") == "detailed balance")
    uniform_pb = (args.pb_type=="uniform")
    r = args.r

    # Log of the reward function of a state
    def log_reward(x):
        """Log-reward of grid coordinates ``x`` (last axis holds the ndim coordinates)."""
        ax = abs(x / (horizon-1) * 2 - 1)
        return ((ax > 0.5).prod(-1) * 0.5 + ((ax < 0.8) * (ax > 0.6)).prod(-1) * 2 + r).log()

    # Encode a state to its one hot version
    def toin(z):
        """One-hot encode each coordinate and flatten to (batch, ndim*horizon)."""
        return T.nn.functional.one_hot(z, horizon).view(z.shape[0],-1).float()

    def compute_ot_reg_at_terminal(cur_s, cur_forward, action_to_next_s, model):
        """OT regularisation for the transition taken out of ``cur_s``.

        Args:
            cur_s: (batch, ndim) long states the transition starts from.
            cur_forward: (batch, ndim+1) forward probabilities at ``cur_s``
                (last column is the terminate action).
            action_to_next_s: (batch, 1) indices of the sampled actions.
            model: Network whose outputs ``[ndim+1 : 2*ndim+1]`` are the
                backward-policy logits.

        Returns:
            (batch,) tensor. Each entry is sum_k -P_F(k | cur_s) * log P_B(cur_s | child_k)
            minus log P_F(sampled action | cur_s).
        """
        # Children of cur_s: increment each coordinate not already at horizon-1
        # (a coordinate at the edge is left unchanged).
        children_of_cur_s_idx = (cur_s!=horizon-1)
        children_of_cur_s = cur_s.unsqueeze(1)+T.diag_embed(children_of_cur_s_idx)
        children_of_cur_s = children_of_cur_s.long()

        # Model outputs for every child, grouped as (batch, ndim, outputs)
        pred_policies_of_children_of_cur_s = model(toin(children_of_cur_s.reshape(-1,ndim))).reshape(children_of_cur_s.size()[0], ndim, -1)

        # Backward policy of each child; backward moves along coordinates equal to 0 are masked out
        backward_prob_of_children_of_cur_s = (pred_policies_of_children_of_cur_s[...,ndim+1:2*ndim+1] - 1000000000*(children_of_cur_s==0).float()).softmax(2)

        # Child k's probability of stepping back along dimension k (i.e. to cur_s),
        # zeroed where the forward move along k has zero probability
        backward_prob_from_children_of_cur_s_to_cur_s = T.diagonal(backward_prob_of_children_of_cur_s, dim1 = -2, dim2 = -1)*(cur_forward[...,:ndim]>0)

        # 1e-13 guards log(0) for masked entries
        ot_reg_temp = T.sum(-cur_forward[...,:ndim]*T.log(backward_prob_from_children_of_cur_s_to_cur_s + 1e-13), 1) - T.log(cur_forward.gather(1,action_to_next_s).reshape(cur_forward.shape[0],) + 1e-13)

        return ot_reg_temp

    # Compute OT regularisation between two states
    def compute_ot_reg(cur_s, cur_forward, action_to_next_s, next_forward, model):
        """Terminal OT term plus the entropy of the next state's forward policy.

        Returns:
            (batch,) tensor: ``compute_ot_reg_at_terminal(...)`` + sum(-p * log p)
            over ``next_forward``.
        """
        ot_reg_temp = compute_ot_reg_at_terminal(cur_s, cur_forward, action_to_next_s, model)
        return ot_reg_temp + T.sum(-next_forward*T.log(next_forward + 1e-13), 1)

    # Compute log of rewards of all states in the grid according to the chosen reward function
    j = T.zeros((horizon,)*ndim+(ndim,))

    for i in range(ndim):
        jj = T.linspace(0,horizon-1,horizon)
        for _ in range(i): jj = jj.unsqueeze(1)
        j[...,i] = jj

    truelr = log_reward(j)
    print('Log of total reward: ', T.logsumexp(truelr.view(-1), 0))

    # Compute true distribution of the normalized reward of all states
    true_dist = T.softmax(truelr.flatten(), 0).cpu().numpy()

    # Learnable normalized constant (added to the TB residual as log Z)
    Z = T.zeros((1,)).to(device)

    # Create NN model for forward and backward policy of a state in GFlowNet.
    # Outputs: [0:ndim+1] forward logits (last = terminate), [ndim+1:2*ndim+1] backward logits
    model = make_mlp([ndim*horizon] + [n_hid] * n_layers + [2*ndim+1]).to(device)
    opt = T.optim.Adam([ {'params':model.parameters(), 'lr':args.learning_rate_model}, {'params':[Z], 'lr':args.learning_rate_z} ])
    Z.requires_grad_()

    print('loss is', 'DB' if detailed_balance else 'TB')

    # Variables
    losses = [] # Total loss
    matching_losses = [] # Matching loss (Specifically: Trajectory balance loss)
    reg_losses = [] # Regularisation loss
    zs = [] # Normalised constant
    all_visited_state = [] # Flat index of every visited terminal state
    all_visited = [] # Coordinates of every visited terminal state
    l1log = [] # (number of visited states, l1 error between true density and empirical density)
    kllog = [] # (number of visited states, KL between true density and empirical density)

    # Sampling settings (loop-invariant)
    exp_weight = args.exp_weight
    temperature = 1

    for it in tqdm.trange(args.n_train_steps):
        opt.zero_grad()

        z = T.zeros((bs,ndim), dtype=T.long).to(device) # current state of each sampled trajectory
        done = T.full((bs,), False, dtype=T.bool).to(device) # trajectory has taken its terminate action
        ot_reg = T.zeros((bs,), dtype=T.float).to(device) # accumulated ot loss for each trajectory
        action = None # action sampled at the previous step (None before the first step)

        # Matching-loss accumulator; for TB it starts at log Z
        if detailed_balance:
            # NOTE(review): the DB branch only accumulates per-step forward/backward
            # log-probabilities; no state-flow or reward terms are added below.
            ll_diff = T.zeros((ndim*horizon, bs)).to(device)
        else:
            ll_diff = T.zeros((bs,)).to(device)
            ll_diff += Z

        i = 0

        while True:
            if i == 0:
                # Prediction from model for current state; edge coordinates can't be incremented
                pred = model(toin(z[~done]))
                edge_mask = T.cat([ (z[~done]==horizon-1).float(), T.zeros(((~done).sum(),1), device=device) ], 1)
                logits = (pred[...,:ndim+1] - 1000000000*edge_mask).log_softmax(1)
            else:
                # Reuse results computed at the end of the previous iteration
                pred = next_pred
                edge_mask = next_edge_mask
                logits = next_logits

            # Calculate log of backward probability (coordinates at 0 can't be decremented)
            init_edge_mask = (z[~done]== 0).float()
            back_logits = F.log_softmax((0 if uniform_pb else 1)*pred[...,ndim+1:2*ndim+1] - 1000000000*init_edge_mask, 1)

            # Subtract log P_B of the transition that led to the current state
            if action is not None:
                if detailed_balance:
                    ll_diff[i-1,~done] -= back_logits.gather(1, action[action!=ndim].unsqueeze(1)).squeeze(1)
                else:
                    ll_diff[~done] -= back_logits.gather(1, action[action!=ndim].unsqueeze(1)).squeeze(1)

            # Sample transition action from a mix of the policy and uniform over allowed actions
            sample_ins_probs = (1-exp_weight)*(logits/temperature).softmax(1) + exp_weight*(1-edge_mask) / (1-edge_mask+0.0000001).sum(1).unsqueeze(1)
            action = sample_ins_probs.multinomial(1)

            # Add log P_F of the sampled action
            if detailed_balance:
                ll_diff[i,~done] += logits.gather(1, action).squeeze(1)
            else:
                ll_diff[~done] += logits.gather(1, action).squeeze(1)

            # Check terminate action (index ndim)
            terminate = (action==ndim).squeeze(1)

            # Record terminal states: flat index sum_k x_k * horizon**k matches true_dist's layout
            for x in z[~done][terminate]:
                state = (x.cpu()*(horizon**T.arange(ndim))).sum().item()
                all_visited_state.append(state)
                all_visited.append(list(x.cpu().numpy()))

            if T.any(terminate):
                # Full-batch mask of trajectories terminating at this step
                just_terminated = T.zeros_like(done)
                just_terminated[~done] = terminate
                ot_reg[just_terminated] += compute_ot_reg_at_terminal(z[~done][terminate], (logits[terminate]).softmax(1), action[terminate], model)

            cur_s = z[~done][~terminate]
            cur_forward = (logits[~terminate]).softmax(1)
            done[~done] |= terminate # Updating status of sampled trajectory

            # Apply non-terminate actions: increment the chosen coordinate
            with T.no_grad():
                z[~done] = z[~done].scatter_add(1, action[~terminate], T.ones(action[~terminate].shape, dtype=T.long, device=device))

            if T.any(~done):
                # Prediction from model for the next state
                next_pred = model(toin(z[~done]))

                next_edge_mask = T.cat([(z[~done]==horizon-1).float(), T.zeros((T.sum(~done),1), device=device) ], 1)
                next_pred_with_mask = next_pred[...,:ndim+1] - 1000000000*next_edge_mask

                # Calculate logits and forward probability (exactly zero where masked)
                next_logits = F.log_softmax(next_pred_with_mask, 1)
                next_forward = (T.exp(next_logits) != 0) *F.softmax(next_pred_with_mask, 1)

                # Calculate ot loss for each transition just sampled
                ot_reg_temp = compute_ot_reg(cur_s, cur_forward, action[~terminate], next_forward, model)
                ot_reg[~done] += ot_reg_temp
            else:
                break
            i += 1

        lens = z.sum(1)+1

        # Matching loss corresponding to log of reward of each sampled trajectory
        if not detailed_balance:
            lr = log_reward(z.float())
            ll_diff -= lr

        # Calculating matching loss
        matching_loss = T.sum(ll_diff**2)/(lens.sum() if detailed_balance else bs)

        # Calculating OT regularisation loss
        reg_loss = T.sum(ot_reg)/(bs)

        # Calculating total loss
        loss = matching_loss + args.reg_coef*reg_loss

        loss.backward()

        opt.step()

        losses.append(loss.item())
        matching_losses.append(matching_loss.item())
        reg_losses.append(reg_loss.item())
        zs.append(Z.item())

        if it%100==0:
            print('Loss =', np.array(losses[-100:]).mean(), 'Z =', Z.item())
            print('Matching loss =', np.array(matching_losses[-100:]).mean(), 'Regularisation loss =', np.array(reg_losses[-100:]).mean(), "\n")

            # Calculate empirical density over the most recent visited states
            emp_dist = np.bincount(all_visited_state[-args.num_empirical_loss:], minlength=len(true_dist)).astype(float)
            emp_dist /= emp_dist.sum()

            # Calculate L1 error
            l1 = np.abs(true_dist-emp_dist).mean()
            l1log.append((len(all_visited), l1))
            print('L1 =', l1)

            kl = -(true_dist * np.log((emp_dist / true_dist)+1e-13)).sum()
            print('KL =', kl)
            kllog.append((len(all_visited), kl))

    # Make sure the output directory exists so a finished run is not lost at save time
    save_dir = os.path.dirname(args.save_path)
    if save_dir:
        os.makedirs(save_dir, exist_ok=True)

    # Close (and flush) the gzip stream deterministically
    with gzip.open(args.save_path, 'wb') as f:
        pickle.dump(
                {'losses': np.float32(losses),
                'matching_losses': np.float32(matching_losses),
                'reg_losses': np.float32(reg_losses),
                'zs': np.float32(zs),
                'params': [p.data.to('cpu').numpy() for p in model.parameters()],
                'visited': np.int8(all_visited),
                'emp_dist_loss': l1log,
                'state_dict': model.state_dict(),
                'kl': kllog,
                'args':args},
                f)

if __name__ == '__main__':
    args = parser.parse_args()

    # Seed Python's, PyTorch's and NumPy's RNGs (in that order) from --seed
    # so that runs are reproducible.
    for seed_rng in (random.seed, T.manual_seed, np.random.seed):
        seed_rng(args.seed)

    main(args)
